import numpy as np
from PIL import Image
import random
import torch
import torch.nn as nn
from copy import deepcopy
import torchvision
import copy
from torchvision import transforms
from torch.nn import functional as F
from sklearn.metrics import roc_auc_score, accuracy_score
from numpy.linalg import norm
import os
import gc
from tqdm import tqdm

from torch.utils.data import Subset

# Model Dataset
from .detector.datasets import ModelDataset

# Loading data loaders
from .data.loaders import get_ood_loader, get_cls_loader

# Loading eval functions
from .eval.eval import evaluate

# Loading constants
from .constants import CLEAN_ROOT_DICT, BAD_ROOT_DICT, NORM_MEAN, NORM_STD
from .constants import num_classes as num_classes_dict

# visualization
from .visualization import visualize_samples
from .attacks.pgd import PGD as Attack

from detector.scores import msp_ood_diff as score_function

from .validate import get_models_scores, get_acc_on_models_scores
from BAD.visualization import visualize_samples

def get_dataloader():
    """Build the sampled OOD evaluation loader and report its size.

    Reads the module-level settings ``source_dataset``, ``out_dataset``,
    ``sample_num`` and ``batch_size`` at call time (they must be defined
    before this is called). It is also passed as ``dataloader_func`` to
    ``get_acc_on_models_scores``, so it takes no arguments.

    Returns:
        The loader produced by ``get_ood_loader`` (with ``sample=True``).
    """
    loader = get_ood_loader(
        in_dataset=source_dataset,
        out_dataset=out_dataset,
        sample_num=sample_num,
        sample=True,
        batch_size=batch_size,
    )
    n_samples = len(loader.dataset)
    print("Size of dataset:", n_samples)
    return loader


# ---------------------------------------------------------------------------
# Experiment configuration (module-level; get_dataloader() reads these).
# ---------------------------------------------------------------------------
arch = 'resnet18'
dataset = 'cifar10'
# In-distribution dataset name used for normalization stats and the OOD loader.
source_dataset = dataset
# TODO(review): out_dataset / sample_num / batch_size were referenced but never
# defined (get_dataloader() raised NameError). These are placeholder values;
# confirm out_dataset is a name accepted by get_ood_loader.
out_dataset = 'cifar100'
sample_num = 1000
batch_size = 256
# Presumably keyed by dataset name, like NORM_MEAN / NORM_STD — confirm.
num_classes = num_classes_dict[source_dataset]
# Python file APIs do not expand '~', so expand it explicitly; derive the
# dataset component from `dataset` instead of hard-coding it.
models_root = os.path.expanduser(f'~/home/models_datasets/{dataset}/{arch}')

if arch == 'preact':
    from BAD.models.loaders import load_preact as model_loader
elif arch in ('resnet', 'resnet18'):
    # `arch` is the concrete name ('resnet18') because it also names the model
    # directory above; previously 'resnet18' fell through to the error branch.
    # NOTE(review): assumes load_resnet builds the ResNet-18 variant — confirm.
    from BAD.models.loaders import load_resnet as model_loader
else:
    raise NotImplementedError("This architecture is not supported")


def final_model_loader(x, meta_data):
    """Load model ``x`` via ``model_loader`` with this experiment's settings.

    Fixes ``num_classes`` and the ``source_dataset`` entries of
    ``NORM_MEAN`` / ``NORM_STD``, passes ``normalize=False``, and forwards
    ``meta_data`` unchanged.
    """
    return model_loader(x,
                        num_classes=num_classes,
                        mean=NORM_MEAN[source_dataset],
                        std=NORM_STD[source_dataset],
                        normalize=False,
                        meta_data=meta_data)


model_dataset = ModelDataset(models_root, final_model_loader, version=None, more_clean=True, balanced=True)
# NOTE(review): the imported PGD `Attack` needs a concrete model instance; no
# single model is selected in this script, so no attack is constructed here.

# get_dataloader() already prints the dataset size.
dataloader = get_dataloader()
visualize_samples(dataloader, 10)

print(get_acc_on_models_scores(model_dataset,
                               score_function,
                               dataloader_func=get_dataloader))

